package org.rockyang.blockj.base.crypto;

import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.generators.PKCS5S2ParametersGenerator;
import org.bouncycastle.crypto.generators.SCrypt;
import org.bouncycastle.crypto.params.KeyParameter;
import org.rockyang.blockj.base.exceptions.CipherException;
import org.rockyang.blockj.base.utils.Numeric;

import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.UUID;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.rockyang.blockj.base.crypto.SecureRandomUtils.secureRandom;

/**
 * <p>Ethereum wallet file management. For reference, refer to
 * <a href="https://github.com/ethereum/wiki/wiki/Web3-Secret-Storage-Definition">
 * Web3 Secret Storage Definition</a> or the
 * <a href="https://github.com/ethereum/go-ethereum/blob/master/accounts/key_store_passphrase.go">
 * Go Ethereum client implementation</a>.</p>
 *
 * <p><strong>Note:</strong> the Bouncy Castle Scrypt implementation
 * {@link SCrypt}, fails to comply with the following
 * Ethereum reference
 * <a href="https://github.com/ethereum/wiki/wiki/Web3-Secret-Storage-Definition#scrypt">
 * Scrypt test vector</a>:</p>
 *
 * <pre>
 * {@code
 * // Only value of r that cost (as an int) could be exceeded for is 1
 * if (r == 1 && N_STANDARD > 65536)
 * {
 *     throw new IllegalArgumentException("Cost parameter N_STANDARD must be > 1 and < 65536.");
 * }
 * }
 * </pre>
 */
public class Wallet {

	private static final int N_LIGHT = 1 << 12;
	private static final int P_LIGHT = 6;

	private static final int N_STANDARD = 1 << 18;
	private static final int P_STANDARD = 1;

	private static final int R = 8;
	private static final int DKLEN = 32;

	private static final int CURRENT_VERSION = 3;

	private static final String CIPHER = "aes-128-ctr";
	static final String AES_128_CTR = "pbkdf2";
	static final String SCRYPT = "scrypt";

	public static WalletFile create(String password, ECKeyPair ecKeyPair, int n, int p)
			throws CipherException
	{

		byte[] salt = generateRandomBytes(32);

		byte[] derivedKey = generateDerivedScryptKey(
				password.getBytes(UTF_8), salt, n, R, p, DKLEN);

		byte[] encryptKey = Arrays.copyOfRange(derivedKey, 0, 16);
		byte[] iv = generateRandomBytes(16);

		byte[] privateKeyBytes =
				Numeric.toBytesPadded(ecKeyPair.getPrivateKeyValue(), Keys.PRIVATE_KEY_SIZE);

		byte[] cipherText = performCipherOperation(
				Cipher.ENCRYPT_MODE, iv, encryptKey, privateKeyBytes);

		byte[] mac = generateMac(derivedKey, cipherText);

		return createWalletFile(ecKeyPair, cipherText, iv, salt, mac, n, p);
	}

	public static WalletFile createStandard(String password, ECKeyPair ecKeyPair)
			throws CipherException
	{
		return create(password, ecKeyPair, N_STANDARD, P_STANDARD);
	}

	public static WalletFile createLight(String password, ECKeyPair ecKeyPair)
			throws CipherException
	{
		return create(password, ecKeyPair, N_LIGHT, P_LIGHT);
	}

	private static WalletFile createWalletFile(
			ECKeyPair ecKeyPair, byte[] cipherText, byte[] iv, byte[] salt, byte[] mac,
			int n, int p)
	{

		WalletFile walletFile = new WalletFile();
		walletFile.setAddress(Keys.getAddress(ecKeyPair));

		WalletFile.Crypto crypto = new WalletFile.Crypto();
		crypto.setCipher(CIPHER);
		crypto.setCiphertext(Numeric.toHexStringNoPrefix(cipherText));

		WalletFile.CipherParams cipherParams = new WalletFile.CipherParams();
		cipherParams.setIv(Numeric.toHexStringNoPrefix(iv));
		crypto.setCipherparams(cipherParams);

		crypto.setKdf(SCRYPT);
		WalletFile.ScryptKdfParams kdfParams = new WalletFile.ScryptKdfParams();
		kdfParams.setDklen(DKLEN);
		kdfParams.setN(n);
		kdfParams.setP(p);
		kdfParams.setR(R);
		kdfParams.setSalt(Numeric.toHexStringNoPrefix(salt));
		crypto.setKdfparams(kdfParams);

		crypto.setMac(Numeric.toHexStringNoPrefix(mac));
		walletFile.setCrypto(crypto);
		walletFile.setId(UUID.randomUUID().toString());
		walletFile.setVersion(CURRENT_VERSION);

		return walletFile;
	}

	private static byte[] generateDerivedScryptKey(
			byte[] password, byte[] salt, int n, int r, int p, int dkLen) throws CipherException
	{
		return SCrypt.generate(password, salt, n, r, p, dkLen);
	}

	private static byte[] generateAes128CtrDerivedKey(
			byte[] password, byte[] salt, int c, String prf) throws CipherException
	{

		if (!prf.equals("hmac-sha256")) {
			throw new CipherException("Unsupported prf:" + prf);
		}

		// Java 8 supports this, but you have to convert the password to a character array, see
		// http://stackoverflow.com/a/27928435/3211687

		PKCS5S2ParametersGenerator gen = new PKCS5S2ParametersGenerator(new SHA256Digest());
		gen.init(password, salt, c);
		return ((KeyParameter) gen.generateDerivedParameters(256)).getKey();
	}

	private static byte[] performCipherOperation(
			int mode, byte[] iv, byte[] encryptKey, byte[] text) throws CipherException
	{

		try {
			IvParameterSpec ivParameterSpec = new IvParameterSpec(iv);
			Cipher cipher = Cipher.getInstance("AES/CTR/NoPadding");

			SecretKeySpec secretKeySpec = new SecretKeySpec(encryptKey, "AES");
			cipher.init(mode, secretKeySpec, ivParameterSpec);
			return cipher.doFinal(text);
		} catch (NoSuchPaddingException | NoSuchAlgorithmException
				| InvalidAlgorithmParameterException | InvalidKeyException
				| BadPaddingException | IllegalBlockSizeException e) {
			throw new CipherException("Error performing cipher operation", e);
		}
	}

	private static byte[] generateMac(byte[] derivedKey, byte[] cipherText)
	{
		byte[] result = new byte[16 + cipherText.length];

		System.arraycopy(derivedKey, 16, result, 0, 16);
		System.arraycopy(cipherText, 0, result, 16, cipherText.length);

		return Hash.sha3(result);
	}

	public static ECKeyPair decrypt(String password, WalletFile walletFile)
			throws Exception
	{

		validate(walletFile);

		WalletFile.Crypto crypto = walletFile.getCrypto();

		byte[] mac = Numeric.hexStringToByteArray(crypto.getMac());
		byte[] iv = Numeric.hexStringToByteArray(crypto.getCipherparams().getIv());
		byte[] cipherText = Numeric.hexStringToByteArray(crypto.getCiphertext());

		byte[] derivedKey;

		WalletFile.KdfParams kdfParams = crypto.getKdfparams();
		if (kdfParams instanceof WalletFile.ScryptKdfParams) {
			WalletFile.ScryptKdfParams scryptKdfParams =
					(WalletFile.ScryptKdfParams) crypto.getKdfparams();
			int dklen = scryptKdfParams.getDklen();
			int n = scryptKdfParams.getN();
			int p = scryptKdfParams.getP();
			int r = scryptKdfParams.getR();
			byte[] salt = Numeric.hexStringToByteArray(scryptKdfParams.getSalt());
			derivedKey = generateDerivedScryptKey(password.getBytes(UTF_8), salt, n, r, p, dklen);
		} else if (kdfParams instanceof WalletFile.Aes128CtrKdfParams) {
			WalletFile.Aes128CtrKdfParams aes128CtrKdfParams =
					(WalletFile.Aes128CtrKdfParams) crypto.getKdfparams();
			int c = aes128CtrKdfParams.getC();
			String prf = aes128CtrKdfParams.getPrf();
			byte[] salt = Numeric.hexStringToByteArray(aes128CtrKdfParams.getSalt());

			derivedKey = generateAes128CtrDerivedKey(password.getBytes(UTF_8), salt, c, prf);
		} else {
			throw new CipherException("Unable to deserialize params: " + crypto.getKdf());
		}

		byte[] derivedMac = generateMac(derivedKey, cipherText);

		if (!Arrays.equals(derivedMac, mac)) {
			throw new CipherException("Invalid password provided");
		}

		byte[] encryptKey = Arrays.copyOfRange(derivedKey, 0, 16);
		byte[] privateKey = performCipherOperation(Cipher.DECRYPT_MODE, iv, encryptKey, cipherText);
		return ECKeyPair.create(privateKey);
	}

	static void validate(WalletFile walletFile) throws CipherException
	{
		WalletFile.Crypto crypto = walletFile.getCrypto();

		if (walletFile.getVersion() != CURRENT_VERSION) {
			throw new CipherException("Wallet version is not supported");
		}

		if (!crypto.getCipher().equals(CIPHER)) {
			throw new CipherException("Wallet cipher is not supported");
		}

		if (!crypto.getKdf().equals(AES_128_CTR) && !crypto.getKdf().equals(SCRYPT)) {
			throw new CipherException("KDF type is not supported");
		}
	}

	static byte[] generateRandomBytes(int size)
	{
		byte[] bytes = new byte[size];
		secureRandom().nextBytes(bytes);
		return bytes;
	}
}
